from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
from models import MODELS_WIDTH1024 as MODELS

# Tasks to include in the summary.
# Key: stem of the per-model result file, read as `<result_dir>/<key>.csv`.
# Value: `(column_name, n_gens)` where `column_name` prefixes the summary
# columns (`<column_name>-<n_gen>`) and each entry of `n_gens` names a column
# of the task CSV to read.
TASKS: Dict[str, Tuple[str, List[int]]] = {
    "gsm8k_0shot_temp0.2": ("gsm8k0shot", [1]),
    # "MATH_0shot_temp0.0_both": ("MATH0shot_both", [1]),
    "MATH_0shot_temp0.2_boxed": ("MATH0shot", [1]),
}

# Root directory under which each model's result directory is resolved.
RESULT_ROOT = Path("results")

# Aggregation methods to report.
# Key: value matched against the `samples` column of a task CSV.
# Value: label stored under `accuracy` in the summary table.
METHODS: Dict[str, str] = {
    "coverage": "coverage",
    "majority vote": "majority vote",
    "most rewarded": "reward + best",
    "reward weighted": "reward + majority",
}


def get_model_info(model_path: str) -> Tuple[int, int, int]:
    """Return the ``MODELS`` entry registered under ``model_path``.

    Callers in this file unpack the entry as ``(width, n_experts, topk)``.

    Raises:
        ValueError: if ``model_path`` is not a key of ``MODELS``.
    """
    if model_path in MODELS:
        return MODELS[model_path]
    raise ValueError(f"Model path {model_path} not found in MODELS.")


def get_model_result_dir(model_path: str) -> Path:
    """Map a model path onto its directory below ``RESULT_ROOT``.

    ``model_path`` is split on "/" and every piece is appended to
    ``RESULT_ROOT`` as its own path component.
    """
    components = model_path.split("/")
    return RESULT_ROOT.joinpath(*components)


def get_task_result(
    result_path: Path,
    method: str,
    n_gens: List[int],
) -> List[float]:
    """Read one method's scores from a task result CSV.

    The CSV must contain a ``samples`` column naming the method of each row,
    plus one column per requested generation count whose header is that count.

    Args:
        result_path: Path of the task result CSV.
        method: Value of the ``samples`` column selecting the row to read.
        n_gens: Generation counts; each is looked up as the column ``str(n)``.

    Returns:
        One value per entry of ``n_gens``, in the same order.

    Raises:
        ValueError: if a requested column does not hold exactly one value for
            ``method`` (no matching row, or more than one).
        KeyError: if the ``samples`` column or a requested column is missing.
    """
    df = pd.read_csv(result_path)
    # Boolean-mask indexing always yields a DataFrame, so no type check is
    # needed before selecting columns from it.
    method_rows = df[df["samples"] == method]

    result: List[float] = []
    for n_gen in n_gens:
        column = str(n_gen)
        values = method_rows[column].values
        # Raise explicitly instead of asserting: asserts are stripped under
        # `python -O`, which would let duplicate rows silently pick the first
        # value and turn a missing row into an unrelated IndexError.
        if len(values) != 1:
            raise ValueError(
                f"Expected one value, got {len(values)} "
                f"(method {method!r}, column {column!r}, file {result_path})"
            )
        result.append(values[0])
    return result


def get_initial_result_dict(model_path: str) -> Dict:
    """Create a fresh result row pre-filled with the model's settings.

    The returned dict holds the keys ``exp``, ``topk`` and ``width`` (in that
    order), taken from ``get_model_info(model_path)``.
    """
    width, n_experts, topk = get_model_info(model_path)
    return {"exp": n_experts, "topk": topk, "width": width}


def main():
    """Collect task scores for every model and method into a summary CSV.

    For each model in ``MODELS`` and each method in ``METHODS`` one row is
    built containing the model settings, the method label (under
    ``accuracy``) and one ``<column_name>-<n_gen>`` entry per task in
    ``TASKS``. A task whose CSV is missing is reported on stdout and left out
    of that row. The rows are indexed by ``exp``, ``topk``, ``width`` and
    ``accuracy``, sorted, written to ``RESULT_ROOT / "summary.csv"`` and
    printed.

    Raises:
        FileNotFoundError: if a model's result directory does not exist.
    """
    all_results: List[Dict] = []
    # Only the model paths are needed here; the per-model settings are looked
    # up again by get_initial_result_dict for every row.
    for model_path in MODELS:
        result_dir = get_model_result_dir(model_path)
        # Explicit raise rather than assert so the check survives `python -O`.
        if not result_dir.exists():
            raise FileNotFoundError(
                f"Result directory {result_dir} does not exist."
            )
        for method, method_column_name in METHODS.items():
            model_results = get_initial_result_dict(model_path)
            model_results["accuracy"] = method_column_name
            for task, (column_name, n_gens) in TASKS.items():
                task_result_file = result_dir / f"{task}.csv"
                if not task_result_file.exists():
                    print(f"Task result file {task_result_file} does not exist.")
                    continue
                task_result = get_task_result(task_result_file, method, n_gens)
                for n_gen, value in zip(n_gens, task_result):
                    model_results[f"{column_name}-{n_gen}"] = value
            all_results.append(model_results)

    # Derive the output location from RESULT_ROOT so it follows the same root
    # the inputs were read from.
    summary_path = RESULT_ROOT / "summary.csv"
    all_results_df = (
        pd.DataFrame(all_results)
        .set_index(["exp", "topk", "width", "accuracy"])
        .sort_index()
    )
    all_results_df.to_csv(summary_path)
    print(all_results_df)
    print(f"Results saved to {summary_path}")


# Run the aggregation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
